// Package lex generates Go code for lexicons.
//
// (It is not a lexer.)
package lex

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"golang.org/x/tools/imports"
)

// MIME content types known to the generator (presumably matched against
// lexicon input/output encodings by the RPC generators elsewhere in the
// package).
const (
	// EncodingJSON is a single JSON document.
	EncodingJSON = "application/json"

	// EncodingJSONL is JSON Lines (one JSON value per line).
	EncodingJSONL = "application/jsonl"

	// EncodingCBOR is a CBOR-encoded body.
	EncodingCBOR = "application/cbor"

	// EncodingCAR is an IPLD CAR archive.
	EncodingCAR = "application/vnd.ipld.car"

	// EncodingMP4 is MP4 video.
	EncodingMP4 = "video/mp4"

	// EncodingANY is the wildcard content type.
	EncodingANY = "*/*"
)

// outputType pairs a Go type name with the lexicon definition it is
// generated from. Slices of these come from Schema.AllTypes (defined
// elsewhere in the package) and are consumed by FixRecordReferences,
// writeDecoderRegister and GenCodeForSchema.
type outputType struct {
	Name      string      // Go identifier the type is emitted as
	Type      *TypeSchema // definition the type is generated from
	NeedsCbor bool        // when set, FixRecordReferences marks union members as needing CBOR
	NeedsType bool        // presumably mirrors Type.needsType ("$type" field) — confirm in AllTypes
}

// BuildExtDefMap builds a map of every definition declared in ss, keyed by
// fully qualified lexicon name: the bare schema ID for the "main"
// definition and "<schema ID>#<def name>" for every other definition.
//
// As a side effect each definition's TypeSchema is annotated with the
// returned map itself, its schema ID, its def name, and the Prefix of the
// first entry in packages that the schema ID starts with ("" if none does).
func BuildExtDefMap(ss []*Schema, packages []Package) map[string]*ExtDef {
	out := make(map[string]*ExtDef)
	for _, s := range ss {
		// The package prefix depends only on the schema ID, so resolve it
		// once per schema instead of once per definition.
		var pref string
		for _, pkg := range packages {
			if strings.HasPrefix(s.ID, pkg.Prefix) {
				pref = pkg.Prefix
				break
			}
		}

		for k, d := range s.Defs {
			d.defMap = out
			d.id = s.ID
			d.defName = k
			d.prefix = pref

			name := s.ID
			if k != "main" {
				name += "#" + k
			}
			out[name] = &ExtDef{Type: d}
		}
	}
	return out
}

// ExtDef is the value type of the map built by BuildExtDefMap; it wraps the
// TypeSchema of a single lexicon definition.
type ExtDef struct {
	Type *TypeSchema
}

// FixRecordReferences propagates marshaling requirements that can only be
// discovered from how types are used. For every schema whose ID starts with
// prefix:
//
//   - every type of kind "record" is marked as needing a "$type" field;
//   - every definition referenced from a union that belongs to a type
//     needing CBOR support is itself marked as needing CBOR support.
//
// Union refs beginning with "#" are resolved relative to the schema they
// appear in. It panics if a union references a definition that is not in
// defmap.
//
// TODO: this method is necessary because in lexicon there is no way to know if
// a type needs to be marshaled with a "$type" field up front, you can only
// know for sure by seeing where the type is used.
func FixRecordReferences(schemas []*Schema, defmap map[string]*ExtDef, prefix string) {
	for _, s := range schemas {
		if !strings.HasPrefix(s.ID, prefix) {
			continue
		}

		for _, t := range s.AllTypes(prefix, defmap) {
			// t is a copy of the outputType, so only writes made through the
			// t.Type pointer outlive this iteration.
			if t.Type.Type == "record" {
				t.Type.needsType = true
			}

			if t.Type.Type != "union" {
				continue
			}
			for _, r := range t.Type.Refs {
				// Local refs ("#foo") are relative to the current schema.
				// HasPrefix (unlike r[0]) is safe on an empty ref.
				if strings.HasPrefix(r, "#") {
					r = s.ID + r
				}

				ext, known := defmap[r]
				if !known {
					panic(fmt.Sprintf("reference to unknown record type: %s", r))
				}

				if t.NeedsCbor {
					ext.Type.needsCbor = true
				}
			}
		}
	}
}

// printerf adapts w into a Printf-style function. Errors returned by
// fmt.Fprintf are discarded.
func printerf(w io.Writer) func(format string, args ...any) {
	emit := func(format string, args ...any) {
		fmt.Fprintf(w, format, args...)
	}
	return emit
}

// GenCodeForSchema renders Go source for schema s and writes it to
// <pkg.Outdir>/<s.Name()>.go, creating pkg.Outdir if necessary.
//
// The generated file contains, in order:
//   - the standard "Code generated" header;
//   - a package clause for pkg.GoPackage;
//   - an import block with an aliased import for every other entry in
//     packages (writeCodeFile runs goimports, which prunes unused imports);
//   - the "$type" registrations from writeDecoderRegister;
//   - every type from s.AllTypes, sorted by Go name;
//   - when reqcode is true, the code writeMethods produces for the "main"
//     definition, if there is one.
//
// Side effects: s and all of its definitions get their prefix set to
// pkg.Prefix, and one "TYPE:" trace line per emitted type is printed to
// stdout.
func GenCodeForSchema(pkg Package, reqcode bool, s *Schema, packages []Package, defmap map[string]*ExtDef) error {
	if err := os.MkdirAll(pkg.Outdir, 0755); err != nil {
		return fmt.Errorf("%s: could not mkdir, %w", pkg.Outdir, err)
	}
	fname := filepath.Join(pkg.Outdir, s.Name()+".go")
	buf := new(bytes.Buffer)
	pf := printerf(buf)

	s.prefix = pkg.Prefix
	for _, d := range s.Defs {
		d.prefix = pkg.Prefix
	}

	// Add the standard Go generated code header as recognized by GitHub, VS Code, etc.
	// See https://golang.org/s/generatedcode.
	pf("// Code generated by cmd/lexgen (see Makefile's lexgen); DO NOT EDIT.\n\n")

	pf("// Lexicon schema: %s\n\n", s.ID)

	pf("package %s\n\n", pkg.GoPackage)

	pf("import (\n")
	pf("\t\"context\"\n")
	pf("\t\"fmt\"\n")
	pf("\t\"encoding/json\"\n")
	pf("\tcbg \"github.com/whyrusleeping/cbor-gen\"\n")
	pf("\tlexutil \"github.com/bluesky-social/indigo/lex/util\"\n")
	for _, xpkg := range packages {
		if xpkg.Prefix != pkg.Prefix {
			pf("\t%s %q\n", importNameForPrefix(xpkg.Prefix), xpkg.Import)
		}
	}
	pf(")\n\n")

	tps := s.AllTypes(pkg.Prefix, defmap)

	// Registrations are emitted in the order AllTypes returned the types,
	// before the slice is sorted below.
	if err := writeDecoderRegister(buf, tps); err != nil {
		return fmt.Errorf("writing type registrations: %w", err)
	}

	sort.Slice(tps, func(i, j int) bool {
		return tps[i].Name < tps[j].Name
	})
	for _, ot := range tps {
		fmt.Println("TYPE: ", ot.Name, ot.NeedsCbor, ot.NeedsType)
		if err := ot.Type.WriteType(ot.Name, buf); err != nil {
			return fmt.Errorf("writing type %s: %w", ot.Name, err)
		}
	}

	// Run always passes reqcode=true.
	if reqcode {
		name := nameFromID(s.ID, pkg.Prefix)
		if main, ok := s.Defs["main"]; ok {
			if err := writeMethods(name, main, buf); err != nil {
				return fmt.Errorf("writing methods for %s: %w", name, err)
			}
		}
	}

	return writeCodeFile(buf.Bytes(), fname)
}

// writeDecoderRegister writes one init function to w. The function calls
// lexutil.RegisterType for every type in tps whose schema needs a "$type"
// field and whose Go name contains no underscore. The registered ID is the
// schema ID, with "#<def name>" appended when the definition has a name.
// Nothing is written when no type qualifies. The only error returned is
// one from w.Write.
func writeDecoderRegister(w io.Writer, tps []outputType) error {
	var body bytes.Buffer
	emit := printerf(&body)

	opened := false
	for _, ot := range tps {
		ts := ot.Type
		if !ts.needsType || strings.Contains(ot.Name, "_") {
			continue
		}

		typeID := ts.id
		if ts.defName != "" {
			typeID += "#" + ts.defName
		}

		if !opened {
			emit("func init() {\n")
			opened = true
		}
		emit("lexutil.RegisterType(%q, &%s{})\n", typeID, ot.Name)
	}

	if !opened {
		return nil
	}
	emit("}\n")

	_, err := w.Write(body.Bytes())
	return err
}

// writeCodeFile formats the Go source b with goimports (imports.Process,
// which fixes the import block and gofmt-formats the code) and writes the
// result to fname.
//
// If formatting fails, the unformatted source is dumped to ./temp (relative
// to the current working directory) for debugging, and the formatting error
// is returned. If writing that dump also fails, both errors are reported so
// the original formatting failure is not lost.
func writeCodeFile(b []byte, fname string) error {
	fixed, err := imports.Process(fname, b, nil)
	if err != nil {
		if werr := os.WriteFile("temp", b, 0664); werr != nil {
			return fmt.Errorf("failed to format output of %q with goimports: %w (also failed to write unformatted output to ./temp: %v)", fname, err, werr)
		}
		return fmt.Errorf("failed to format output of %q with goimports: %w (wrote failed file to ./temp)", fname, err)
	}

	return os.WriteFile(fname, fixed, 0664)
}

// writeMethods emits the top-level code for a schema's main definition ts,
// using typename as the generated name:
//
//   - "token": a string constant holding the token's ID (the schema ID,
//     plus "#<def name>" unless the def is "main");
//   - "query": delegates to ts.WriteRPC with typename+"_Input" as the input
//     type name;
//   - "procedure": delegates to ts.WriteRPC. The input type name is
//     typename+"_Input" when there is no input schema or it is an "object".
//     It is the name resolved from the ref when the input schema is a "ref".
//     Any other input schema type is an error;
//   - "record", "object", "string", "subscription": nothing.
//
// Any other definition type is an error.
func writeMethods(typename string, ts *TypeSchema, w io.Writer) error {
	switch ts.Type {
	case "token":
		tokenID := ts.id
		if ts.defName != "main" {
			tokenID += "#" + ts.defName
		}
		fmt.Fprintf(w, "const %s = %q\n", typename, tokenID)
		return nil

	case "query":
		return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename))

	case "procedure":
		in := ts.Input
		switch {
		case in == nil || in.Schema == nil || in.Schema.Type == "object":
			return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename))
		case in.Schema.Type == "ref":
			inputName, _ := ts.namesFromRef(in.Schema.Ref)
			return ts.WriteRPC(w, typename, inputName)
		default:
			return fmt.Errorf("unhandled input type: %s", in.Schema.Type)
		}

	case "record", "object", "string":
		return nil

	case "subscription":
		// TODO: should probably have some methods generated for this
		return nil

	default:
		return fmt.Errorf("unrecognized lexicon type %q", ts.Type)
	}
}

// nameFromID converts a lexicon ID into a Go type name relative to prefix.
// It strips prefix from id, title-cases each remaining dot-separated
// segment, and concatenates the segments. For example, ("app.bsky.feed.post",
// "app.bsky") yields "FeedPost".
func nameFromID(id, prefix string) string {
	var sb strings.Builder
	for _, seg := range strings.Split(strings.TrimPrefix(id, prefix), ".") {
		// strings.Title is deprecated but kept so generated names stay stable.
		sb.WriteString(strings.Title(seg))
	}
	return sb.String()
}

// orderedMapIter calls cb for each entry of m in ascending key order. It
// stops at the first non-nil error from cb and returns it. Sorting the keys
// gives callers a deterministic order despite Go's randomized map iteration.
func orderedMapIter[T any](m map[string]T, cb func(string, T) error) error {
	// Pre-size: the number of keys is known up front.
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	for _, k := range keys {
		if err := cb(k, m[k]); err != nil {
			return err
		}
	}
	return nil
}

// CreateHandlerStub generates files in Go package pkg:
//
//   - <dir>/stubs.go, via WriteXrpcServer, always;
//   - <dir>/handlers.go, via WriteServerHandlers, when handlers is true.
//
// impmap maps lexicon prefixes to Go import paths. Each file goes through
// writeCodeFile (goimports) before being written. It stops at the first
// error.
func CreateHandlerStub(pkg string, impmap map[string]string, dir string, schemas []*Schema, handlers bool) error {
	stubs := new(bytes.Buffer)
	if err := WriteXrpcServer(stubs, schemas, pkg, impmap); err != nil {
		return err
	}
	if err := writeCodeFile(stubs.Bytes(), filepath.Join(dir, "stubs.go")); err != nil {
		return err
	}

	if !handlers {
		return nil
	}

	impls := new(bytes.Buffer)
	if err := WriteServerHandlers(impls, schemas, pkg, impmap); err != nil {
		return err
	}
	return writeCodeFile(impls.Bytes(), filepath.Join(dir, "handlers.go"))
}

// importNameForPrefix derives the Go import alias for a lexicon prefix by
// removing every dot; for example, "app.bsky" becomes "appbsky".
func importNameForPrefix(prefix string) string {
	// Single pass; equivalent to splitting on "." and joining with "".
	return strings.ReplaceAll(prefix, ".", "")
}

// WriteServerHandlers writes the source of Go package pkg to w. The output
// is an import block covering every entry of impmap (lexicon prefix → Go
// import path), followed by a handler stub for every schema whose main
// definition is a query or procedure. Stubs are generated by the main
// definition's WriteHandlerStub. Schemas without a main definition are
// skipped, with a warning printed to stdout.
//
// Prefixes are sorted before use, so imports and prefix matching do not
// depend on map iteration order. Each schema uses the first prefix, in
// sorted order, that its ID starts with ("" if none matches).
func WriteServerHandlers(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error {
	prefixes := make([]string, 0, len(impmap))
	for k := range impmap {
		prefixes = append(prefixes, k)
	}
	sort.Strings(prefixes)

	pf := printerf(w)
	pf("package %s\n\n", pkg)
	pf("import (\n")
	pf("\t\"context\"\n")
	pf("\t\"fmt\"\n")
	pf("\t\"encoding/json\"\n")
	pf("\t\"github.com/bluesky-social/indigo/xrpc\"\n")
	for _, k := range prefixes {
		pf("\t%s %q\n", importNameForPrefix(k), impmap[k])
	}
	pf(")\n\n")

	for _, s := range schemas {
		var prefix string
		for _, k := range prefixes {
			if strings.HasPrefix(s.ID, k) {
				prefix = k
				break
			}
		}

		main, ok := s.Defs["main"]
		if !ok {
			fmt.Printf("WARNING: schema %q doesn't have a main def\n", s.ID)
			continue
		}

		if main.Type == "procedure" || main.Type == "query" {
			fname := idToTitle(s.ID)
			tname := nameFromID(s.ID, prefix)
			impname := importNameForPrefix(prefix)
			if err := main.WriteHandlerStub(w, fname, tname, impname); err != nil {
				return fmt.Errorf("writing handler stub for %s: %w", s.ID, err)
			}
		}
	}

	return nil
}

// WriteXrpcServer writes the source of Go package pkg to w.
//
// The output starts with an import block covering every entry of impmap
// (lexicon prefix → Go import path). Then, for each prefix in sorted order,
// it emits:
//   - a RegisterHandlers<Title(prefix)> method on *Server that routes each
//     query (GET) and procedure (POST) at "/xrpc/<schema ID>" to
//     s.Handle<Title(schema ID)>;
//   - those handler methods, generated by the main definition's
//     WriteRPCHandler.
//
// Each schema belongs to the first prefix, in sorted order, that its ID
// starts with. A schema matching no prefix is an error. Schemas without a
// main definition, or whose main definition is neither a query nor a
// procedure, are skipped.
func WriteXrpcServer(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error {
	pf := printerf(w)
	pf("package %s\n\n", pkg)
	pf("import (\n")
	pf("\t\"context\"\n")
	pf("\t\"fmt\"\n")
	pf("\t\"encoding/json\"\n")
	pf("\t\"github.com/bluesky-social/indigo/xrpc\"\n")
	pf("\t\"github.com/labstack/echo/v4\"\n")

	var prefixes []string
	// The callback never fails, so orderedMapIter cannot return an error.
	_ = orderedMapIter(impmap, func(k, v string) error {
		prefixes = append(prefixes, k)
		pf("\t%s %q\n", importNameForPrefix(k), v)
		return nil
	})
	pf(")\n\n")

	ssets := make(map[string][]*Schema)
	for _, s := range schemas {
		var pref string
		for _, p := range prefixes {
			if strings.HasPrefix(s.ID, p) {
				pref = p
				break
			}
		}
		if pref == "" {
			return fmt.Errorf("no matching prefix for schema %q (tried %s)", s.ID, prefixes)
		}

		ssets[pref] = append(ssets[pref], s)
	}

	for _, p := range prefixes {
		ss := ssets[p]

		pf("func (s *Server) RegisterHandlers%s(e *echo.Echo) error {\n", idToTitle(p))
		for _, s := range ss {
			main, ok := s.Defs["main"]
			if !ok {
				continue
			}

			var verb string
			switch main.Type {
			case "query":
				verb = "GET"
			case "procedure":
				verb = "POST"
			default:
				continue
			}

			pf("e.%s(\"/xrpc/%s\", s.Handle%s)\n", verb, s.ID, idToTitle(s.ID))
		}

		pf("return nil\n}\n\n")

		// Every schema in ss was grouped under p above, so p is its prefix.
		// Re-deriving it from impmap would follow random map order and could
		// pick a different prefix when prefixes overlap.
		impname := importNameForPrefix(p)
		for _, s := range ss {
			main, ok := s.Defs["main"]
			if !ok {
				continue
			}
			if main.Type != "procedure" && main.Type != "query" {
				continue
			}

			fname := idToTitle(s.ID)
			tname := nameFromID(s.ID, p)
			if err := main.WriteRPCHandler(w, fname, tname, impname); err != nil {
				return fmt.Errorf("writing handler for %s: %w", s.ID, err)
			}
		}
	}

	return nil
}

// idToTitle turns a dotted lexicon ID into a Go identifier. It title-cases
// each dot-separated segment and concatenates the results; for example,
// "app.bsky.feed.post" becomes "AppBskyFeedPost".
func idToTitle(id string) string {
	segments := strings.Split(id, ".")
	for i, seg := range segments {
		// strings.Title is deprecated but kept so generated names stay stable.
		segments[i] = strings.Title(seg)
	}
	return strings.Join(segments, "")
}

// Package describes one generated Go package and the lexicon schemas it
// covers. ParsePackages decodes a JSON array of these.
type Package struct {
	GoPackage string `json:"package"` // Go package name used in the generated package clause
	Prefix    string `json:"prefix"`  // lexicon ID prefix; schemas whose ID starts with it are generated here
	Outdir    string `json:"outdir"`  // directory generated files are written to
	Import    string `json:"import"`  // import path other generated packages use to import this one
}

// ParsePackages decodes jsonBytes, which should hold a JSON array of Package
// objects, and returns the decoded packages. On a decoding error it returns
// a nil slice together with the error from encoding/json.
func ParsePackages(jsonBytes []byte) ([]Package, error) {
	var pkgs []Package
	if err := json.Unmarshal(jsonBytes, &pkgs); err != nil {
		return nil, err
	}
	return pkgs, nil
}

// Run generates Go code for schemas according to packages.
//
// externalSchemas are only used as reference targets: they enter the
// definition map but get no code of their own. For each package, Run first
// propagates "$type"/CBOR requirements across the schemas under its prefix
// (FixRecordReferences). It then generates one file per schema whose ID
// starts with that package's prefix. A schema matching several prefixes is
// generated once per matching package. Run stops at the first generation
// error.
func Run(schemas []*Schema, externalSchemas []*Schema, packages []Package) error {
	// Build the combined list in a fresh slice. append(schemas, ...) could
	// write externalSchemas into spare capacity of the caller's backing array.
	all := make([]*Schema, 0, len(schemas)+len(externalSchemas))
	all = append(all, schemas...)
	all = append(all, externalSchemas...)
	defmap := BuildExtDefMap(all, packages)

	for _, pkg := range packages {
		FixRecordReferences(schemas, defmap, pkg.Prefix)
	}

	for _, pkg := range packages {
		for _, s := range schemas {
			if !strings.HasPrefix(s.ID, pkg.Prefix) {
				continue
			}

			if err := GenCodeForSchema(pkg, true, s, packages, defmap); err != nil {
				return fmt.Errorf("failed to process schema %q: %w", s.path, err)
			}
		}
	}
	return nil
}
